
[Feature] Define backends and add Triton backend for Lora #3161

Merged (9 commits) on Feb 4, 2025

Conversation

@Fridge003 (Collaborator) commented on Jan 27, 2025

Motivation

The current LoRA modules rely on the SGEMM kernels provided by FlashInfer for their computation. However, FlashInfer is not well optimized for the tall-and-thin matrices of LoRA modules. In addition, the way LoraManager manages the segment indices and weight indices of an input batch is inefficient. Together, these issues make LoRA run slowly in SGLang.

Modifications

To improve the efficiency of LoRA, this PR makes the following modifications on the basis of PR draft #1728:

  1. Define BaseLoraBackend, FlashInferLoraBackend and TritonLoraBackend classes, which decouple the GEMM implementation of each backend from the forward logic of the LoRA modules. A new server argument lora-backend is added for controlling the backend (a sketch of these interfaces follows this list).
  2. Define a BatchInfo class that packs [bs, seg_lens, seg_indptr, max_len, weight_indices] together. By attaching it to the LoRA backend, it only needs to be set once per batch forward.
  3. Add Triton kernels that run GEMM more efficiently, including an SGEMM kernel for lora_a (large K, small N), an SGEMM kernel for lora_b (large N, small K), and a fused kernel for the qkv lora_b modules.
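A rough sketch of the interfaces described above, assuming signatures consistent with the run_lora_a_sgemm / run_lora_b_sgemm call sites quoted later in this thread; the BatchInfo fields follow item 2, while everything else is illustrative rather than the exact code merged in this PR:

from dataclasses import dataclass

import torch


@dataclass
class BatchInfo:
    """Per-batch segment/weight metadata, attached to the backend once per forward pass."""
    bs: int                       # number of sequences in the batch
    seg_lens: torch.Tensor        # token count of each sequence's segment
    seg_indptr: torch.Tensor      # prefix-sum offsets into the flattened token dimension
    max_len: int                  # longest segment, used to bound kernel launches
    weight_indices: torch.Tensor  # which LoRA adapter slot each sequence uses


class BaseLoraBackend:
    """Backend-agnostic GEMM interface; LoRA modules call these instead of a concrete kernel."""

    def __init__(self, name: str, batch_info: BatchInfo = None):
        self.name = name
        self.batch_info = batch_info

    def set_batch_info(self, batch_info: BatchInfo):
        # Set once per batch so individual modules need not recompute indices.
        self.batch_info = batch_info

    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError


class TritonLoraBackend(BaseLoraBackend):
    """Would dispatch to the new Triton segmented-GEMM kernels (kernel bodies omitted here)."""

    def run_lora_a_sgemm(self, x, weights, **kwargs):
        ...  # launch the Triton lora_a kernel using self.batch_info

    def run_lora_b_sgemm(self, x, weights, **kwargs):
        ...  # launch the Triton lora_b kernel, optionally applying base_output and scaling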

Usage

A new argument, lora-backend, is added to the server arguments. It can be either triton or flashinfer, indicating which backend to use; the default value is triton.
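For example, a server could be launched with the Triton backend roughly like this (the model path, adapter path, and the --lora-paths flag follow SGLang's existing LoRA launch options and are placeholders, not commands taken from this PR):

# Hypothetical example, not taken from this PR
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf --lora-paths /path/to/adapter --max-loras-per-batch 4 --lora-backend triton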

Accuracy Test

The accuracy test can be run with:

python test/srt/models/test_lora_backend.py

The code passes the accuracy test on both H100 and A6000 machines.

Benchmarking result

To benchmark LoRA, run this command to launch the server:

# Triton backend
python benchmark/lora/launch_server.py --max-loras-per-batch 4 --lora-backend triton

# Flashinfer backend
# python benchmark/lora/launch_server.py --max-loras-per-batch 4 --lora-backend flashinfer

# Base model without lora
# python benchmark/lora/launch_server.py --base-only

Then run this command to send test requests from the client:

python benchmark/lora/lora_bench.py

Benchmark configurations:

  • base model: meta-llama/Llama-2-7b-hf
  • LoRA adapter: winddude/wizardLM-LlaMA-LoRA-7B
  • GPU: NVIDIA H100
  • maximum number of LoRAs per batch: 4
  • number of requests: 50
  • input length: uniform random distribution on [1, 1024]
  • output length: uniform random distribution on [1, 128]
| Backend | Total Throughput (tok/s) | Mean E2E Latency (ms) |
| --- | --- | --- |
| Triton | 2040.96 | 7165.67 |
| FlashInfer | 1606.97 | 9270.38 |
| No LoRA | 3090.54 | 4776.19 |

Further Optimization

There are two main bottlenecks for LoRA with the current Triton backend:

  • On prefill batches with long sequences, the LoRA process has to wait for the prior non-LoRA kernels to complete, which takes a long time. I tried using multiple CUDA streams, but the synchronization overhead is much larger than the time saved.
  • The overhead of Triton's compilation process, which can only be avoided by replacing Triton.

The payoff from autotuning is small, since SGEMM on LoRA modules has low arithmetic intensity; the current kernels are already fast enough without it.
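As a back-of-the-envelope illustration of that low arithmetic intensity (the shapes below are assumptions based on Llama-2-7B's hidden size and a typical LoRA rank, not measurements from this PR):

# Rough arithmetic intensity of a lora_a SGEMM: (s x K) @ (K x r) in fp16
s, K, r = 512, 4096, 16                                   # segment tokens, hidden size, LoRA rank (assumed)
bytes_per_elem = 2                                        # fp16
flops = 2 * s * K * r                                     # multiply-adds
bytes_moved = bytes_per_elem * (s * K + K * r + s * r)    # read x and A, write the output
print(flops / bytes_moved)                                # ~15 FLOPs/byte, far below the hundreds of
                                                          # FLOPs/byte an H100 can sustain, so the kernel
                                                          # is memory-bound and tile-size tuning buys little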

The most promising way to optimize the LoRA kernels further is to add a CUDA/CUTLASS backend, which would eliminate the Triton compilation time.


Fridge003 mentioned this pull request on Jan 26, 2025
Fridge003 changed the title from "[Feature] Define Gemm backends and add Triton backend for Lora" to "[Feature] Define backends and add Triton backend for Lora" on Jan 27, 2025
Fridge003 force-pushed the lora_triton branch 5 times, most recently from 6a6dadd to 90a5123, on February 1, 2025
zhaochenyang20 merged commit 70817a7 into sgl-project:main on Feb 4, 2025 (15 checks passed)
Fridge003 deleted the lora_triton branch on February 4, 2025
@HaiShaw (Collaborator) commented on Feb 4, 2025

@Ying1123 we don't have flashinfer on ROCm yet; this merge causes a break on AMD.

@zhyncs (Member) commented on Feb 4, 2025

@HaiShaw AMD CIs are crucial for preventing such issues from a process perspective.

@zhyncs (Member) commented on Feb 4, 2025

@HaiShaw Also, could you help fix the top of the main branch?

@HaiShaw (Collaborator) commented on Feb 4, 2025

> @HaiShaw AMD CIs are crucial for preventing such issues from a process perspective.

Yes, let me push on it!

HaiShaw mentioned this pull request on Feb 4, 2025
Comment on lines +218 to +223
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
@Edenzzzz (Contributor) commented on Feb 4, 2025

Just curious, would there be a benefit in fusing these two ops?

@Fridge003 (Collaborator, Author) replied:

Fusing two neighboring GEMMs would be really hard to implement, and the benefit is uncertain.

@zhaochenyang20 (Collaborator) commented:

@Fridge003 Please check these.

@zhaochenyang20 (Collaborator) commented:

I think it has been fixed by the AMD folks.
